"""Apply pre-computed transforms to put a stat map in template space.

Assumptions:
    File hierarchy laid out per description in beau_anat_to_template.py

How to use it:
    python beau_stats_to_template.py {subj_id} {stat_map_path}

If the stat map is in BRIK/HEAD format, omit the file extension,
e.g., just "my_stat_map+orig".

Example:

python beau_stats_to_template.py shellgame_15 shellgame_15/shellgame_15.results/statmap+orig
"""

from __future__ import print_function
import os
import subprocess
import sys


def epi_to_mni(anat_to_template_warp_path, anat_to_template_affine_path,
               anat_to_epi_itk_path, stat_map_path, template_path,
               output_path):
    """Resample a stat map into template space with antsApplyTransforms.

    The transforms are passed to ANTs in this order: the anat-to-template
    warp, the anat-to-template affine, and the anat-to-EPI affine with the
    "use inverse" flag set (i.e. it is applied as EPI-to-anat).

    Args:
        anat_to_template_warp_path: Warp field from anat to template space.
        anat_to_template_affine_path: Affine from anat to template space.
        anat_to_epi_itk_path: ITK-format anat-to-EPI affine; it is inverted
            by ANTs when applied.
        stat_map_path: Input image to be resampled.
        template_path: Reference image defining the output grid.
        output_path: Where the resampled image is written.

    Raises:
        AssertionError: If any of the two anat-to-template transforms, the
            anat-to-EPI transform, or the stat map is not an existing file.
    """
    required_files = (
        anat_to_template_warp_path,
        anat_to_template_affine_path,
        anat_to_epi_itk_path,
        stat_map_path,
    )
    missing = [path for path in required_files if not os.path.isfile(path)]
    if missing:
        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`; the exception type is kept for existing
        # callers.
        raise AssertionError(
            "Required input file(s) not found: {0}".format(', '.join(missing))
        )

    cmd = [
        'antsApplyTransforms',
        '-i', stat_map_path,
        '-d', '3',
        '-o', output_path,
        '-r', template_path,
        '-t', anat_to_template_warp_path,
        '-t', anat_to_template_affine_path,
        # "[file,1]" asks ANTs to apply the inverse of this transform.
        '-t', "[{0},1]".format(anat_to_epi_itk_path)
    ]
    subprocess.call(cmd)


def convert_transform_afni_to_itk(oned_path, output_path=None):
    """Convert an AFNI 1D 'Form 1A' affine transform to ITK text for ANTS.

    AFNI transforms in 'Form 1A' hold a 3x3 matrix ``u`` and a translation
    vector ``v`` on a single line of text, interleaved row by row:
        u11 u12 u13 v1 u21 u22 u23 v2 u31 u32 u33 v3

    Whereas ITK transforms list the matrix first and the translation last:
        u11 u12 u13 u21 u22 u23 u31 u32 u33 v1 v2 v3

    This function creates a new text file with the proper boilerplate and puts
    the numbers in the correct order for ITK/ANTS. Only the first data line
    of the input is used; comment lines (starting with '#') and blank lines
    are skipped, and values may be separated by any whitespace.

    Args:
        oned_path: Path to the AFNI .1D transform file.
        output_path: Where to write the ITK transform. Defaults to
            ``oned_path`` with its extension replaced by '_itk.txt'.

    Raises:
        ValueError: If the file has no data line or the first data line has
            fewer than 12 values.
    """
    with open(oned_path, 'r') as oned_file:
        data_lines = [
            line for line in oned_file
            if line.strip() and not line.lstrip().startswith('#')
        ]
    if not data_lines:
        raise ValueError(
            "No transform values found in {0}".format(oned_path)
        )

    # split() with no argument handles tabs and runs of spaces alike.
    oned_values = data_lines[0].split()
    if len(oned_values) < 12:
        raise ValueError(
            "Expected 12 values in {0}, found {1}".format(
                oned_path, len(oned_values)
            )
        )
    print(oned_values)

    # Matrix entries first (row by row), then the translation column.
    itk_values = [
        oned_values[0], oned_values[1], oned_values[2],
        oned_values[4], oned_values[5], oned_values[6],
        oned_values[8], oned_values[9], oned_values[10],
        oned_values[3], oned_values[7], oned_values[11]
    ]
    if output_path is None:
        output_path = os.path.splitext(oned_path)[0] + '_itk.txt'
    with open(output_path, 'w') as itk_file:
        itk_string = (
            "#Insight Transform File V1.0\n"
            "#Transform 0\n"
            "Transform: AffineTransform_double_3_3\n"
            "Parameters: {0}\n\n"
            "FixedParameters: 0 0 0\n".format(' '.join(itk_values))
        )
        print(itk_string, file=itk_file)

def set_metadata(map_path):
    """Run AFNI's 3drefit on *map_path* with '-view tlrc -space TLRC'.

    The command is passed as an argument list (no shell), so paths that
    contain spaces or shell metacharacters are handed to 3drefit verbatim.
    The exit status of 3drefit is not checked.
    """
    cmd = ['3drefit', '-view', 'tlrc', '-space', 'TLRC', map_path]
    subprocess.call(cmd)


def gzfix(path):
    """Return *path* unchanged if it exists, otherwise *path* + '.gz'."""
    return path if os.path.exists(path) else path + '.gz'


def afni_to_nifti(stat_map_path):
    """Convert a BRIK/HEAD stat map to NIFTI next to the original dataset.

    Runs 3dAFNItoNIFTI, which writes its output into the current working
    directory, then moves that file beside the source dataset.

    Args:
        stat_map_path: Dataset path without extension, e.g. 'dir/map+orig'.

    Returns:
        Path of the resulting '.nii' (or '.nii.gz') file.
    """
    subprocess.call(['3dAFNItoNIFTI', stat_map_path])

    # Strip the trailing '+view' suffix only; rpartition keeps any earlier
    # '+' characters in directory or file names intact.
    prefix, plus, _view = stat_map_path.rpartition('+')
    if not plus:
        prefix = stat_map_path
    nifti_path = prefix + '.nii'

    # Because 3dAFNItoNIFTI saves in the current working directory
    produced_name = os.path.basename(nifti_path)
    if os.path.exists(produced_name + '.gz'):
        produced_name += '.gz'
        nifti_path += '.gz'
    os.rename(produced_name, nifti_path)
    return nifti_path


def template_output_path(stat_map_path):
    """Return the output path for the template-space copy of a stat map.

    A trailing '.nii.gz' or '.nii' is replaced by '_template.nii'; any other
    extension is replaced the same way.
    """
    # Check the longer suffix first so '.nii.gz' is removed as a whole.
    for extension in ('.nii.gz', '.nii'):
        if stat_map_path.endswith(extension):
            return stat_map_path[:-len(extension)] + '_template.nii'
    return os.path.splitext(stat_map_path)[0] + '_template.nii'


def main(argv=None):
    """Command-line entry point: warp one subject's stat map to the template.

    Args:
        argv: Argument list in ``sys.argv`` layout (program name, subject id,
            stat map path). Defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 3:
        sys.exit(
            "Usage: python beau_stats_to_template.py "
            "{subj_id} {stat_map_path}"
        )
    subj_id = argv[1]
    stat_map_path = argv[2]

    print("Converting anat-to-epi transform from AFNI to ITK format...")

    anat_to_epi_afni_path = os.path.join(
        subj_id,
        "{0}.BrainExtractionBrain_al_mat.aff12.1D".format(subj_id)
    )
    anat_to_epi_itk_path = os.path.join(
        subj_id,
        "{0}.anat_to_epi_itk.txt".format(subj_id)
    )
    convert_transform_afni_to_itk(
        anat_to_epi_afni_path,
        anat_to_epi_itk_path
    )

    if not stat_map_path.endswith(('.nii', '.nii.gz')):
        print("Converting stat map to NIFTI format...")
        stat_map_path = afni_to_nifti(stat_map_path)

    output_path = template_output_path(stat_map_path)

    anat_to_template_warp_path = os.path.join(
        subj_id,
        "{0}.1Warp.nii.gz".format(subj_id)
    )
    anat_to_template_affine_path = os.path.join(
        subj_id,
        "{0}.0GenericAffine.mat".format(subj_id)
    )
    # The template is looked up relative to the current working directory.
    template_path = "TT_N27_SurfVol.nii"
    epi_to_mni(
        anat_to_template_warp_path,
        anat_to_template_affine_path,
        anat_to_epi_itk_path,
        stat_map_path,
        template_path,
        output_path
    )

    set_metadata(output_path)

    print("All done.")


if __name__ == "__main__":
    main()
